Image classification with Google's ViT and Hugging Face 🤗
Fine-tuning a ViT model with Hugging Face.
🔖 Learning objectives
- In A brief introduction to Vision Transformer (ViT) we will explore the core ideas that define a vision transformer model.
- In Loading the data we will take advantage of the Hugging Face Hub and its libraries to download a cool dataset.
- In EDA with pandas we will use pandas to explore the dataset.
- In Image pre-processing we will learn how to pre-process the images for training.
- In Training we will fine-tune a pre-trained vision transformer to perform a down-stream task.
- In Evaluation we will see how the model performs after fine-tuning.
- In Data augmentation we will add extra transformations to the pre-processing pipeline to make our model more robust to unseen data.
- Acknowledgments contains the main resources I used as a source of inspiration and learning to write this post.
- Additional resources contains extra materials that I highly recommend checking out.
The Vision Transformer (ViT) is a transformer-based architecture for computer vision tasks proposed in the An Image is Worth 16x16 Words paper by researchers at Google in 2020. The ViT leverages only the encoder part of the original Transformer architecture and each input image is interpreted as a sequence of patch embeddings.
In the pre-processing phase, an image is broken down into a grid of square patches. Each patch is flattened and linearly projected onto a lower-dimensional space. A learnable position embedding is then added to each resulting patch embedding to retain positional information. These steps are applied to every patch of the image in parallel.
Before feeding the sequence of patch embeddings and position embeddings to the encoder, an extra learnable "class" embedding is prepended to the sequence. The last hidden state of this embedding serves as representation of the overall image. A classification head on top of the transformer encoder maps the last hidden state of the class embedding onto the final vector space representing the labels.
In this notebook we will learn how to take a ViT model pre-trained on a large dataset and fine-tune it to perform image classification on a smaller task-specific dataset. We will use google/vit-base-patch16-224-in21k as the pre-trained model. This checkpoint was generated by pre-training a ViT on ImageNet-21k at a resolution of 224x224; during pre-processing, images were converted into sequences of 16x16 patches. As for the data, we will use the Matthijs/snacks dataset, which consists of 6,745 images across 20 classes.
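To make the patch-embedding step concrete, here is the arithmetic for this checkpoint; the numbers below follow directly from the 224x224 resolution and 16x16 patch size.

```python
# Patch arithmetic for the google/vit-base-patch16-224-in21k checkpoint:
# a 224x224 RGB image is split into non-overlapping 16x16 patches.
image_size = 224
patch_size = 16
channels = 3

patches_per_side = image_size // patch_size     # 14 patches per side
num_patches = patches_per_side ** 2             # 196 patches in total
patch_dim = channels * patch_size * patch_size  # 768 values per flattened patch

# The learnable "class" embedding is prepended, so the
# encoder sees a sequence of 197 embeddings per image.
sequence_length = num_patches + 1

print(num_patches, patch_dim, sequence_length)  # 196 768 197
```

Conveniently, the flattened patch dimension (768) matches the hidden size of the base ViT encoder, though in general the linear projection maps patches to whatever hidden size the model uses.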
We make sure our instance has a GPU and pip install the Hugging Face libraries.
!nvidia-smi
!pip install datasets -Uqq
!pip install transformers[sentencepiece] -Uqq
import numpy as np
import random
import torch
def set_seeds(seed=1234):
    """Set seeds for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seeds(2077)
We can conveniently get access to the dataset using the load_dataset function from the datasets library.
The resulting DatasetDict object already contains a train, validation and test set as Dataset objects. Each dataset has two features: image and label.
Images are PIL.JpegImagePlugin.JpegImageFile objects, whereas labels are simple integers.
We can inspect the content of a Dataset as if we were inspecting a normal list.
from datasets import load_dataset
raw_data = load_dataset('Matthijs/snacks')
raw_data
raw_data['train'][:3]
We can get the full list of labels through the features attribute. They are represented as a ClassLabel object.
Each of the 20 labels represents a kind of snack... Some are healthier than others. 🍭
labels = raw_data['train'].features['label']
labels
Let's visualize some of the images with matplotlib to get a better understanding of the data. The ClassLabel class has an int2str method that allows us to map the IDs to the actual labels.
import matplotlib.pyplot as plt
nrows = 2
ncols = 5
plt.figure(figsize=(ncols*2, nrows*2), dpi=100)
for i in range(1, nrows*ncols + 1):
    sample_image = raw_data['train'][i]
    plt.subplot(nrows, ncols, i)
    plt.imshow(sample_image['image'].resize((128, 128)))
    plt.title(labels.int2str(sample_image['label']), fontsize=10)
    plt.axis('off')
plt.show()
The images we visualized all belong to a single class, most likely because the dataset is ordered by label.
Let's plot some images once again, but this time using a random sample.
random_sample = raw_data['train'].shuffle(seed=2077).select(range(100))
nrows = 2
ncols = 5
plt.figure(figsize=(ncols*2, nrows*2), dpi=100)
for i in range(nrows*ncols):
    sample_image = random_sample[i]
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(sample_image['image'].resize((128, 128)))
    plt.title(labels.int2str(sample_image['label']), fontsize=10)
    plt.axis('off')
plt.show()
Now we're talking!
With Hugging Face DatasetDict and Dataset classes we can convert the data into a pandas.DataFrame object to investigate its content leveraging our knowledge of the pandas library.
We just need to change the output format with the set_format method and grab the entire dataset by slicing it with [:].
raw_data.set_format('pandas')
raw_data_pd = raw_data['train'][:]
raw_data_pd.shape
We count the number of images available for each label to check if there is any imbalance. Then we visualize the result as a barplot using seaborn.
samples_count = (
    raw_data_pd
    .groupby('label')
    .count()
    .reset_index()
    .rename(columns={'image': 'count'})
)
# map IDs to labels
samples_count['label'] = samples_count['label'].map(lambda x: labels.int2str(x))
samples_count
import seaborn as sns
sns.set_theme(style='whitegrid', palette='pastel')
plt.figure(dpi=100)
bar = sns.barplot(y='label', x='count', data=samples_count, orient='h')
plt.show()
As we can observe from the table and bar plot above, most classes contain about 250 images. pineapple is the most common class with 260 images, whereas pretzel is the least represented with only 154 images.
Overall, we can say this dataset is pretty well balanced.
Before moving to the next section, let's make sure to restore the original format of the data with the reset_format method.
raw_data.reset_format()
As with text, images also require some sort of pre-processing before they can be fed to a ViT. The most basic approach just requires resizing the images and normalizing them across the channels with the same statistics used during pre-training.
We will see how to increase the robustness of our model with the help of data augmentation in a section below.
The transformers library provides a ViTFeatureExtractor class that automatically performs resizing and normalization according to the pre-training statistics.
To load the pre-processing pipeline, sometimes referred to as a feature extractor, we simply pass the google/vit-base-patch16-224-in21k checkpoint as an argument to the from_pretrained method when instantiating the ViTFeatureExtractor object.
It specifies that images need to be resized to a resolution of 224x224 and normalized across channels using the image_mean and image_std statistics.
from transformers import ViTFeatureExtractor
checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
feature_extractor
We can test the feature extractor by simply passing a sample image as an argument.
What we get in return is a dictionary with a pixel_values key, whose value is a list containing a numpy.ndarray of shape 3x224x224 representing the pre-processed image.
sample_image = random_sample['image'][0]
sample_image
sample_image_features = feature_extractor(sample_image)
sample_image_features
sample_image_features['pixel_values'][0].shape
We can visually inspect the pre-processed image by first converting the numpy.ndarray into a torch.Tensor and then converting the tensor into a PIL.Image.Image with the help of transforms.ToPILImage() from the torchvision library.
from torchvision import transforms
import torch
transforms.ToPILImage()(torch.tensor(sample_image_features['pixel_values'][0]))
Now that we have a basic grasp of how the ViT feature extractor works, we can define a function that extracts the pixel_values from each data point, and apply it on-the-fly to batches with the set_transform method every time __getitem__ is called.
We can use the deepcopy function from the copy module to duplicate the datasets and keep the originals unaltered.
def apply_feature_extractor(examples):
    examples['pixel_values'] = [
        torch.tensor(feature_extractor(image.convert('RGB'))['pixel_values'][0])
        for image in examples['image']
    ]
    return examples
import copy
data_preprocessed = copy.deepcopy(raw_data)
data_preprocessed.set_transform(apply_feature_extractor)
data_preprocessed
To see the function at work we simply need to access a sample.
data_preprocessed['train'][0]
It seems to work properly. Every time one or more samples are accessed, the pre-processing function is invoked to generate the pixel_values, which are then added to the features.
Just to be 100% sure the pre-processing transformation works as intended, let's visualize a couple of images.
nrows = 2
ncols = 5
plt.figure(figsize=(ncols*2, nrows*2), dpi=100)
preprocessed_samples = data_preprocessed['train'].shuffle(seed=2077).select(range(nrows * ncols))
for i in range(nrows*ncols):
    preprocessed_sample = preprocessed_samples[i]
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(transforms.ToPILImage()(torch.tensor(preprocessed_sample['pixel_values'])).resize((128, 128)))
    plt.title(labels.int2str(preprocessed_sample['label']), fontsize=10)
    plt.axis('off')
plt.show()
It's looking good!
Before moving to the training phase, there is one last step: rename the label feature to labels, because by default the model looks for this specific column to compute the loss.
data_cleaned = data_preprocessed.rename_column('label', 'labels')
data_cleaned
We can get access to a pre-trained ViT model using the ViTForImageClassification class and use its from_pretrained method to load a pre-existing checkpoint.
In addition to the checkpoint, we also need to set num_labels to correctly represent the number of classes in our dataset.
In this way, during initialization the ViTForImageClassification class will replace the default head with a new one intended for our down-stream dataset.
For the sake of completeness, in this specific example we also define the id2label and label2id parameters to map IDs to labels and vice versa, but they are totally optional and they are not as relevant as num_labels.
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=labels.num_classes,
    id2label={index: label for index, label in enumerate(labels.names)},
    label2id={label: index for index, label in enumerate(labels.names)}
)
By default the training loop reports back only the training and validation loss, which from a human perspective can be difficult to interpret.
For this reason, we define a function to compute some of the metrics commonly used in a classification task: accuracy, precision, recall and f1 score.
They are all available through the load_metric function from the datasets library.
We define a function that, at each evaluation step, takes the validation predictions as input and reports back the metrics we chose.
import numpy as np
from datasets import load_metric
# define function to compute metrics
def compute_metrics_fn(eval_preds):
    metrics = dict()
    accuracy_metric = load_metric('accuracy')
    precision_metric = load_metric('precision')
    recall_metric = load_metric('recall')
    f1_metric = load_metric('f1')

    logits = eval_preds.predictions
    labels = eval_preds.label_ids
    preds = np.argmax(logits, axis=-1)

    metrics.update(accuracy_metric.compute(predictions=preds, references=labels))
    metrics.update(precision_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(recall_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(f1_metric.compute(predictions=preds, references=labels, average='weighted'))
    return metrics
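The heart of this function is the np.argmax step that turns raw logits into class predictions before comparing them with the references. A tiny self-contained example with made-up logits shows the idea; only accuracy is computed here, by hand, to keep it dependency-free.

```python
import numpy as np

# Toy batch of logits for 4 samples and 3 classes (made-up numbers).
logits = np.array([
    [2.0, 0.1, 0.3],   # highest score in position 0 -> predicted class 0
    [0.2, 1.5, 0.1],   # predicted class 1
    [0.3, 0.2, 2.2],   # predicted class 2
    [1.1, 0.9, 0.2],   # predicted class 0
])
references = np.array([0, 1, 2, 1])  # ground-truth labels

preds = np.argmax(logits, axis=-1)   # pick the highest-scoring class per sample
accuracy = (preds == references).mean()
print(preds, accuracy)               # [0 1 2 0] 0.75
```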
The TrainingArguments class from the transformers library is responsible for gathering up all the parameters that customize the training loop. The only required parameter is an output_dir to save predictions and checkpoints, but in this case we also specify the number of epochs, the batch size and how frequently we want to evaluate the model.
By default remove_unused_columns is set to True, so that during training the model automatically drops features it doesn't expect, like the image items in our dataset. However, we can't drop this feature, because the set_transform method and the apply_feature_extractor function we defined earlier require the raw images to generate the pixel_values feature.
That's why we need to explicitly set this parameter to False.
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir='vit-run#0001',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_steps=200,
    logging_steps=200,
    evaluation_strategy='steps',
    eval_steps=200,
    load_best_model_at_end=True,
    remove_unused_columns=False,
    fp16=True
)
The data collator is the piece of code responsible for assembling individual samples into batches.
Since we instructed the training loop to keep all the features, we can no longer rely on the DefaultDataCollator from the transformers library. During training it would attempt to feed the original images to the model as well, but as we know by now, the ViT model only accepts pixel_values and labels as input features.
We can overcome the issue by defining our own collate function. During training and inference, the collate function receives a list of data points as input; the length of the list is defined by the batch size.
In this specific example, a data point corresponds to a dictionary with three key-value pairs: image, labels and pixel_values. Knowing that, we define our custom collate function so that it keeps only pixel_values and labels and ignores image.
pixel_values are 3D (channels, height, width) tensors, so we bundle them together with torch.stack to form a 4D tensor of shape (batch size, channels, height, width). labels are plain integers, so it's sufficient to convert the list containing them into a torch.tensor.
# e.g. examples = [{'image': ..., 'labels': ..., 'pixel_values': ...}]
def collate_fn(examples):
    pixel_values = torch.stack([example['pixel_values'] for example in examples])
    labels = torch.tensor([example['labels'] for example in examples])
    return {'pixel_values': pixel_values, 'labels': labels}
We can test the custom collate function by providing a list of examples.
In this case we assume that a batch includes 8 samples.
collate_test = collate_fn([data_cleaned['train'][i] for i in range(8)])
print(collate_test['pixel_values'].shape)
print(collate_test['labels'])
It returned what we were expecting and so we're good to move on!
The Trainer class from the transformers library is the training loop itself. We provide all the items we've generated so far:
- pre-trained model
- training arguments
- custom collate function
- training and validation dataset
- custom metrics function
from transformers import Trainer
trainer = Trainer(
    model,
    training_args,
    data_collator=collate_fn,
    train_dataset=data_cleaned['train'],
    eval_dataset=data_cleaned['validation'],
    compute_metrics=compute_metrics_fn
)
The train method starts the training process.
💡 It is completely fine to get different numbers from one iteration to another.
trainer.train()